from __future__ import print_function

from multiprocessing import Process, Manager
import multiprocessing
import random
import sys
import threading
from functools import partial, reduce

from optparse import OptionParser

from ioutils_direct import *
from models import *
import os
import tensorflow.contrib.eager as tfe

# Enable TF eager mode. Once enabled it cannot be undone! Execute only once.
tfe.enable_eager_execution()


multiprocessing.freeze_support()
# Ask TF to apply automatic mixed-precision graph rewriting where supported.
os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'
'''
Script for training the core finder model

Key changes from NIPS paper version:
- Addition of "rich" options for atom featurization with more informative descriptors
- Predicted reactivities are not 1D, but 5D and explicitly identify what the bond order of the product should be
'''

# Top-k sizes: NK is used by predict()'s top_k; NK0 is referenced only by the
# legacy (commented-out) accuracy loop at the bottom of the file.
NK = 20
NK0 = 10

parser = OptionParser()
parser.add_option("-t", "--train", dest="train_path")
parser.add_option("-m", "--save_dir", dest="save_path")
parser.add_option("-b", "--batch", dest="batch_size", default=20)
parser.add_option("-w", "--hidden", dest="hidden_size", default=100)
parser.add_option("-d", "--depth", dest="depth", default=1)
parser.add_option("-l", "--max_norm", dest="max_norm", default=5.0)
parser.add_option("-r", "--rich", dest="rich_feat", default=False)
opts, args = parser.parse_args()

'''
batch_size = int(opts.batch_size)
hidden_size = int(opts.hidden_size)
depth = int(opts.depth)
max_norm = float(opts.max_norm)
train_path = opts.train_path
'''
# NOTE(review): the parsed CLI options are overridden by the hard-coded
# values below; restore the commented block above to honour the command line.
opts.train_path = "/home/firefox/PycharmProjects/rexgen_direct (copy)/rexgen_direct/data/train.txt.proc"
opts.save_path = "/home/firefox/PycharmProjects/rexgen_direct (copy)/rexgen_direct/core_wln_global/mymodel"

batch_size = 1
hidden_size = 300
depth = 3
max_norm = 5.0  # global-norm gradient clipping threshold (see grads_fn)
# train_path = '../data/train.txt.proc'
# save_path = '../mymodel'
opts.rich_feat = 0  # forced off: always import the basic featurization below
if opts.rich_feat:
    from mol_graph_rich import atom_fdim as adim, bond_fdim as bdim, max_nb, smiles2graph_list as _s2g
else:
    from mol_graph import atom_fdim as adim, bond_fdim as bdim, max_nb, smiles2graph_list as _s2g

# Atom indices come from the SMILES atom-map numbers (1-based -> 0-based).
smiles2graph_batch = partial(
    _s2g, idxfunc=lambda x: x.GetIntProp('molAtomMapNumber') - 1)

# gpu_options = tf.GPUOptions(allow_growth=True)
# session = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
#
# config = tf.ConfigProto(device_count={"CPU": 12},     # limit to num_cpu_core CPU usage
#                 inter_op_parallelism_threads = 1,
#                 intra_op_parallelism_threads = 1,
#                 log_device_placement=True)
#
# session = tf.Session(config=config)


# filename_queue = tf.train.string_input_producer(
#     ['/home/firefox/PycharmProjects/rexgen_direct (copy)/rexgen_direct/data/train.proc50000.tfrecord'])
# reader = tf.TFRecordReader()
# _, serialized_example = reader.read(filename_queue)

def parse_chem_sequence(record):
    """Parse one serialized TFRecord example into model inputs/targets.

    Each tensor is stored twice in the record: as a flat VarLen float list
    and as a companion "<name>_shape" feature holding its original
    dimensions, so the flat data can be reshaped after decoding.

    Returns:
        ((graph_inputs, binary), (label, sp_label)) where graph_inputs is
        (input_atom, input_bond, atom_graph, bond_graph, num_nbs, node_mask).
    """
    features = tf.parse_single_example(record, features={
        "input_atom": tf.VarLenFeature(tf.float32),
        "input_bond": tf.VarLenFeature(tf.float32),
        "atom_graph": tf.VarLenFeature(tf.float32),
        "bond_graph": tf.VarLenFeature(tf.float32),
        "num_nbs": tf.VarLenFeature(tf.float32),
        "node_mask": tf.VarLenFeature(tf.float32),
        "label": tf.VarLenFeature(tf.float32),
        "binary": tf.VarLenFeature(tf.float32),
        "sp_label": tf.VarLenFeature(tf.float32),

        "input_atom_shape": tf.VarLenFeature(tf.float32),
        "input_bond_shape": tf.VarLenFeature(tf.float32),
        "atom_graph_shape": tf.VarLenFeature(tf.float32),
        "bond_graph_shape": tf.VarLenFeature(tf.float32),
        "num_nbs_shape": tf.VarLenFeature(tf.float32),
        "node_mask_shape": tf.VarLenFeature(tf.float32),
        "label_shape": tf.VarLenFeature(tf.float32),
        "binary_shape": tf.VarLenFeature(tf.float32),
        "sp_label_shape": tf.VarLenFeature(tf.float32),
    })

    input_atom = features['input_atom']
    input_bond = features['input_bond']
    atom_graph = features['atom_graph']
    bond_graph = features['bond_graph']
    num_nbs = features['num_nbs']
    node_mask = features['node_mask']
    label = features['label']
    binary = features['binary']
    sp_label = features['sp_label']
    input_atom_shape = features['input_atom_shape']
    input_bond_shape = features['input_bond_shape']
    atom_graph_shape = features['atom_graph_shape']
    bond_graph_shape = features['bond_graph_shape']
    num_nbs_shape = features['num_nbs_shape']
    node_mask_shape = features['node_mask_shape']
    label_shape = features['label_shape']
    binary_shape = features['binary_shape']
    sp_label_shape = features['sp_label_shape']

    # Cast to working dtypes (everything was serialized as float32).
    # NOTE(review): sp_label is never cast here (stays float32) and most of
    # the *_shape tensors are cast but unused — only binary_shape is applied
    # below. Confirm downstream consumers expect sp_label as float.
    input_atom = tf.cast(input_atom, tf.float32)
    input_bond = tf.cast(input_bond, tf.float32)
    atom_graph = tf.cast(atom_graph, tf.int32)
    bond_graph = tf.cast(bond_graph, tf.int32)
    num_nbs = tf.cast(num_nbs, tf.int32)
    node_mask = tf.cast(node_mask, tf.float32)
    label = tf.cast(label, tf.int32)
    binary = tf.cast(binary, tf.float32)
    input_atom_shape = tf.cast(input_atom_shape, tf.int32)
    input_bond_shape = tf.cast(input_bond_shape, tf.int32)
    atom_graph_shape = tf.cast(atom_graph_shape, tf.int32)
    bond_graph_shape = tf.cast(bond_graph_shape, tf.int32)
    num_nbs_shape = tf.cast(num_nbs_shape, tf.int32)
    node_mask_shape = tf.cast(node_mask_shape, tf.int32)
    label_shape = tf.cast(label_shape, tf.int32)
    binary_shape = tf.cast(binary_shape, tf.int32)
    sp_label_shape = tf.cast(sp_label_shape, tf.int32)

    # VarLen features decode as SparseTensors; densify before reshaping.
    input_atom = tf.sparse_tensor_to_dense(input_atom)
    input_bond = tf.sparse_tensor_to_dense(input_bond)
    atom_graph = tf.sparse_tensor_to_dense(atom_graph)
    bond_graph = tf.sparse_tensor_to_dense(bond_graph)
    num_nbs = tf.sparse_tensor_to_dense(num_nbs)
    node_mask = tf.sparse_tensor_to_dense(node_mask)
    label = tf.sparse_tensor_to_dense(label)
    binary = tf.sparse_tensor_to_dense(binary)
    sp_label = tf.sparse_tensor_to_dense(sp_label)
    binary_shape = tf.sparse_tensor_to_dense(binary_shape)
    # sp_label = sp_label.eval(session=session)

    # Restore the per-example shapes; feature dims come from mol_graph
    # (adim/bdim/max_nb), binary uses its serialized shape minus the
    # leading (batch) dimension.
    input_atom = tf.reshape(input_atom, [-1, adim])
    input_bond = tf.reshape(input_bond, [-1, bdim])
    atom_graph = tf.reshape(atom_graph, [-1, max_nb])   # max_nb == 10
    bond_graph = tf.reshape(bond_graph, [-1, max_nb])
    num_nbs = tf.reshape(num_nbs, [-1])
    node_mask = tf.reshape(node_mask, [-1])
    label = tf.reshape(label, [-1])
    binary = tf.reshape(binary, binary_shape[1:])

    # Pin static shapes so downstream layers can build.
    input_atom.set_shape([None, adim])
    input_bond.set_shape([None, bdim])
    atom_graph.set_shape([None, max_nb])
    bond_graph.set_shape([None, max_nb])
    num_nbs.set_shape([None])
    node_mask.set_shape([None])
    label.set_shape([None])
    binary.set_shape([None, None, binary_fdim])
    graph_inputs = (input_atom, input_bond, atom_graph,
                    bond_graph, num_nbs, node_mask)
    input_data = (graph_inputs, binary)
    target = (label, sp_label)
    # return input_atom, input_bond, atom_graph, bond_graph, num_nbs, node_mask, label, binary, sp_label
    return input_data, target


# Training set: serialized reaction graphs (hard-coded 100-record shard).
training_data = tf.data.TFRecordDataset(
    '/home/firefox/PycharmProjects/rexgen_direct/rexgen_direct/data/train.proc100.tfrecord')
# dataset = dataset.map(parse_chem_sequence).shuffle(buffer_size=10000)
# dataset = dataset.padded_batch(batch_size, padded_shapes=([None],[],[]))
training_data = training_data.map(parse_chem_sequence)
# training_data = training_data.batch(1)

# Validation set, parsed identically; batching happens manually in fit().
eval_data = tf.data.TFRecordDataset(
    '/home/firefox/PycharmProjects/rexgen_direct/rexgen_direct/data/valid.proc100.tfrecord')
eval_data = eval_data.map(parse_chem_sequence)
# eval_data = eval_data.batch(1)


class Linear(tf.keras.layers.Layer):
    """Simple dense layer: y = x @ w + b.

    The kernel is drawn from N(0, stddev) with stddev scaled by the fan-in
    (capped at 0.1); the bias uses the default random-normal initializer.
    """

    def __init__(self, units=32):
        super(Linear, self).__init__()
        self.units = units  # output dimensionality

    def build(self, input_shape):
        # Fan-in-scaled initialization, as in the original WLN code.
        stddev = min(1.0 / math.sqrt(input_shape[-1]), 0.1)
        # Create BOTH variables through add_weight so the layer tracks them
        # uniformly (the original mixed a raw tf.Variable for w with
        # add_weight for b).
        self.w = self.add_weight(shape=(input_shape[-1], self.units),
                                 initializer=tf.random_normal_initializer(
                                     stddev=stddev),
                                 dtype='float32',
                                 trainable=True)
        self.b = self.add_weight(shape=(self.units,),
                                 initializer='random_normal',
                                 trainable=True)

    def call(self, inputs, **kwargs):
        """Apply the affine transform along the last axis of `inputs`."""
        return tf.matmul(inputs, self.w) + self.b


class linearND(tf.keras.layers.Layer):
    """Dense layer applied to the last axis of an N-D tensor.

    Flattens all leading axes, applies x @ w (plus a bias when `init_bias`
    is not None), then restores the leading shape with `units` as the final
    axis.
    """

    def __init__(self, units=300, init_bias=None):
        super(linearND, self).__init__()
        self.units = units
        # When not None, a bias variable is created and initialized to this
        # constant (the original ignored the value and always used
        # 'random_normal', defeating callers that pass init_bias=0).
        self.init_bias = init_bias

    def build(self, input_shape):
        # Fan-in-scaled initialization, capped at 0.1.
        # (Removed a dead `target_shape` computation that was never used.)
        stddev = min(1.0 / math.sqrt(int(input_shape[-1])), 0.1)
        w_init = tf.random_normal_initializer(stddev=stddev)
        self.w = tf.Variable(initial_value=w_init(shape=(input_shape[-1], self.units),
                                                  dtype='float32'),
                             trainable=True)
        if self.init_bias is not None:
            self.b = self.add_weight(shape=(self.units,),
                                     initializer=tf.constant_initializer(
                                         self.init_bias),
                                     trainable=True)

    def call(self, inputs, **kwargs):
        # Collapse every leading axis, matmul, then restore the shape with
        # the new feature dimension appended.
        input_shape = inputs.get_shape().as_list()
        ndim = len(input_shape)
        X_shape = tf.gather(tf.shape(inputs), list(range(ndim - 1)))
        target_shape = tf.concat([X_shape, [self.units]], 0)
        exp_input = tf.reshape(inputs, [-1, input_shape[-1]])
        if self.init_bias is None:
            res = tf.matmul(exp_input, self.w)
        else:
            res = tf.matmul(exp_input, self.w) + self.b
        res = tf.reshape(res, target_shape)
        res.set_shape(input_shape[:-1] + [self.units])
        return res


class Rcnn_Wl_Last(tf.keras.layers.Layer):
    '''Performs the WLN embedding (local, no attention mechanism).

    Runs `depth` rounds of Weisfeiler-Lehman-style message passing over the
    molecular graph and returns the final per-atom feature map plus its sum
    over atoms (a molecular fingerprint).
    '''

    def __init__(self, hidden_size, depth=3):
        super(Rcnn_Wl_Last, self).__init__()
        self.hidden_size = hidden_size  # width of every hidden projection
        self.linearND_atom_embedding = linearND(
            units=hidden_size, init_bias=None)  # hidden_size is default to 300
        self.linearND_nei_atom = linearND(units=hidden_size, init_bias=None)
        self.linearND_nei_bond = linearND(units=hidden_size, init_bias=None)
        self.linearND_self_atom = linearND(units=hidden_size, init_bias=None)
        self.linearND_label_U2 = linearND(units=hidden_size, init_bias=0)
        self.linearND_label_U1 = linearND(units=hidden_size, init_bias=0)
        self.depth = depth

    def build(self, input_shape):
        # Sublayers build themselves on first call; nothing to create here.
        pass

    def call(self, inputs, **kwargs):
        graph_inputs = inputs
        input_atom, input_bond, atom_graph, bond_graph, num_nbs, node_mask = graph_inputs
        batch_size = input_atom.get_shape().as_list()[0]
        atom_features = tf.nn.relu(self.linearND_atom_embedding(input_atom))
        # Broadcast the per-atom mask across the feature dimension. (The
        # original hard-coded 300 here; use the configured hidden size.)
        node_mask = tf.reshape(node_mask, (tf.shape(node_mask)[0], -1, 1))
        node_mask = tf.tile(node_mask, [1, 1, self.hidden_size])
        layers = []

        for i in range(self.depth):
            # Gather neighbour atom/bond features via the graph index tensors.
            fatom_nei = tf.gather_nd(atom_features, atom_graph)
            fbond_nei = tf.gather_nd(input_bond, bond_graph)
            h_nei_atom = self.linearND_nei_atom(fatom_nei)
            h_nei_bond = self.linearND_nei_bond(fbond_nei)
            h_nei = h_nei_atom * h_nei_bond
            # Zero out padded neighbour slots (each atom has <= max_nb real
            # neighbours).
            mask_nei = tf.reshape(tf.sequence_mask(tf.reshape(num_nbs, [-1]), max_nb, dtype=tf.float32),
                                  [batch_size, -1, max_nb, 1])
            f_nei = tf.reduce_sum(h_nei * mask_nei, -2)
            f_self = self.linearND_self_atom(atom_features)
            # Masked neighbour/self combination is this round's output.
            # (Removed a leftover try/except with a pdb.set_trace() hook.)
            layers.append(f_nei * f_self * node_mask)   # output
            l_nei = tf.concat([fatom_nei, fbond_nei], 3)
            nei_label = tf.nn.relu(self.linearND_label_U2(l_nei))
            nei_label = tf.reduce_sum(nei_label * mask_nei, -2)
            new_label = tf.concat([atom_features, nei_label], 2)
            new_label = self.linearND_label_U1(new_label)
            atom_features = tf.nn.relu(new_label)   # update atom features
        # kernels = tf.concat(1, layers)
        # atom FPs are the final output after "depth" convolutions
        kernels = layers[-1]
        fp = tf.reduce_sum(kernels, 1)  # molecular FP is sum over atom FPs
        return kernels, fp


class chemFormularGCN(tf.keras.Model):
    """Core-finder GCN: WLN embedding + global attention + bond-change scores.

    Runs in eager mode. `predict` scores every atom pair for each of 5 bond
    orders; `fit` trains with the supplied optimizer and global-norm
    gradient clipping, with early stopping on the eval loss.
    """

    def __init__(self, device='cpu:0',
                 checkpoint_directory='/home/firefox/PycharmProjects/rexgen_direct/rexgen_direct/core_wln_global/mymodel'):
        super(chemFormularGCN, self).__init__()
        self.device = device
        self.checkpoint_directory = checkpoint_directory
        self.batch_size = 1
        self.history = {'train_loss': [], 'eval_loss': []}

        # WLN graph convolution producing per-atom embeddings.
        self.rcnn_wl_last = Rcnn_Wl_Last(hidden_size=300, depth=3)
        # Attention branch: scores every atom pair to build context features.
        self.linearND_att_atom_feature = linearND(
            units=hidden_size, init_bias=None)
        self.linearND_att_bin_feature = linearND(
            units=hidden_size, init_bias=0)
        self.linearND_att_scores = linearND(units=1, init_bias=None)
        # Scoring branch: combines local, binary and context pair features.
        self.linearND_atom_feature = linearND(
            units=hidden_size, init_bias=None)
        self.linearND_bin_feature = linearND(units=hidden_size, init_bias=None)
        self.linearND_ctx_feature = linearND(units=hidden_size, init_bias=None)
        self.linearND_scores = linearND(units=5, init_bias=None)

    def predict(self, input_data, bmask):
        """Score all atom pairs.

        Args:
            input_data: (graph_inputs, binary) as produced by the dataset.
            bmask: additive penalty (large on invalid pairs) subtracted from
                the scores before top-k selection.

        Returns:
            (topk indices, topk scores, attention scores, flat scores).

        NOTE(review): uses the module-level `batch_size` / `hidden_size`
        globals for reshapes, not `self.batch_size` — confirm they agree
        when training with a non-default batch size.
        """
        graph_inputs, binary = input_data
        input_atom, input_bond, atom_graph, bond_graph, num_nbs, node_mask = graph_inputs
        node_mask = tf.expand_dims(node_mask, -1)

        # Perform the WLN embedding
        atom_hiddens, _ = self.rcnn_wl_last(graph_inputs)

        # Calculate local atom pair features as sum of local atom features
        atom_hiddens1 = tf.reshape(
            atom_hiddens, [batch_size, 1, -1, hidden_size])
        atom_hiddens2 = tf.reshape(
            atom_hiddens, [batch_size, -1, 1, hidden_size])
        atom_pair = atom_hiddens1 + atom_hiddens2

        # Calculate attention scores for each pair of atoms
        att_hidden = tf.nn.relu(self.linearND_att_atom_feature(
            atom_pair) + self.linearND_att_bin_feature(binary))
        att_score = self.linearND_att_scores(att_hidden)
        att_score = tf.nn.sigmoid(att_score)

        # Calculate context features using those attention scores
        att_context = att_score * atom_hiddens1
        att_context = tf.reduce_sum(att_context, 2)

        # Calculate global atom pair features as sum of atom context features
        att_context1 = tf.reshape(
            att_context, [batch_size, 1, -1, hidden_size])
        att_context2 = tf.reshape(
            att_context, [batch_size, -1, 1, hidden_size])
        att_pair = att_context1 + att_context2

        # Calculate likelihood of each pair of atoms to form a particular bond order
        pair_hidden = self.linearND_atom_feature(atom_pair) + self.linearND_bin_feature(binary) + \
            self.linearND_ctx_feature(att_pair)

        pair_hidden = tf.nn.relu(pair_hidden)
        pair_hidden = tf.reshape(pair_hidden, [batch_size, -1, hidden_size])
        score = self.linearND_scores(pair_hidden)
        score = tf.reshape(score, [batch_size, -1])

        topk_scores, topk = tf.nn.top_k(score - bmask, k=NK)
        flat_score = tf.reshape(score, [-1])
        return topk, topk_scores, att_score, flat_score

    def loss_fn(self, input_data, target):
        """Sigmoid cross-entropy over valid atom-pair labels.

        INVALID_BOND entries contribute no loss (via `bond_mask`) and are
        pushed out of predict()'s top-k with a large penalty (`bmask`).
        (Removed a leftover pdb.set_trace() that halted every call.)
        """
        label, _ = target
        flat_label = tf.reshape(label, [-1])
        bond_mask = tf.to_float(tf.not_equal(flat_label, INVALID_BOND))
        # Clamp the INVALID_BOND sentinel to 0 (those entries are masked
        # out of the loss anyway).
        flat_label = tf.maximum(0, flat_label)
        bmask = tf.to_float(tf.equal(label, INVALID_BOND)) * 10000
        _, _, _, flat_score = self.predict(input_data, bmask)
        # Train with categorical crossentropy
        loss = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=flat_score, labels=tf.to_float(flat_label))
        loss = tf.reduce_sum(loss * bond_mask)
        return loss / self.batch_size

    def grads_fn(self, input_data, target):
        """Compute clipped gradients of the loss w.r.t. this model's own
        trainable variables.

        (The original referenced the module-global `model` instead of
        `self`, which breaks for any instance other than that global.)
        """
        with tfe.GradientTape() as tape:
            loss = self.loss_fn(input_data, target)
        grads = tape.gradient(loss, self.trainable_variables)
        new_grads, _ = tf.clip_by_global_norm(grads, max_norm)
        return new_grads

    def restore_model(self):
        """Restore the latest checkpoint from `self.checkpoint_directory`."""
        with tf.device(self.device):
            saver = tfe.Saver(self.variables)
            saver.restore(tf.train.latest_checkpoint
                          (self.checkpoint_directory))

    def save_model(self, global_step=0):
        """Save the model variables under `self.checkpoint_directory`."""
        tfe.Saver(self.variables).save(self.checkpoint_directory,
                                       global_step=global_step)

    # ---- batching helpers (previously triplicated inline inside fit) ----

    @staticmethod
    def _pack2D(arr_list):
        # Zero-pad a list of 2-D arrays to a common (N, M) and stack them.
        N = max([x.shape[0] for x in arr_list])
        M = max([x.shape[1] for x in arr_list])
        a = np.zeros((len(arr_list), N, M))
        for i, arr in enumerate(arr_list):
            a[i, 0:arr.shape[0], 0:arr.shape[1]] = arr
        return a

    @staticmethod
    def _pack2D_withidx(arr_list):
        # As _pack2D, but each cell becomes (batch_index, value) pairs,
        # suitable for tf.gather_nd over a batched tensor.
        N = max([x.shape[0] for x in arr_list])
        M = max([x.shape[1] for x in arr_list])
        a = np.zeros((len(arr_list), N, M, 2))
        for i, arr in enumerate(arr_list):
            n = arr.shape[0]
            m = arr.shape[1]
            a[i, 0:n, 0:m, 0] = i
            a[i, 0:n, 0:m, 1] = arr
        return a

    @staticmethod
    def _pack1D(arr_list):
        # Zero-pad a list of 1-D arrays to a common length and stack them.
        N = max([x.shape[0] for x in arr_list])
        a = np.zeros((len(arr_list), N))
        for i, arr in enumerate(arr_list):
            a[i, 0:arr.shape[0]] = arr
        return a

    @staticmethod
    def _get_mask(arr_list):
        # 1.0 for rows that exist in each example, 0.0 for padding.
        N = max([x.shape[0] for x in arr_list])
        a = np.zeros((len(arr_list), N))
        for i, arr in enumerate(arr_list):
            for j in range(arr.shape[0]):
                a[i][j] = 1
        return a

    def _iter_batches(self, dataset, batch_size):
        """Yield packed (input_data, target) minibatches from `dataset`.

        NOTE(review): the yielded `target` belongs to the LAST example of
        each batch, exactly as in the original loops — only correct for
        batch_size == 1; confirm before training with larger batches.
        A trailing partial batch is dropped, as before.
        """
        fatom_list, fbond_list = [], []
        gatom_list, gbond_list = [], []
        nb_list, binary_list = [], []
        for input_data, target in tfe.Iterator(dataset):
            graph_inputs, binary = input_data
            input_atom, input_bond, atom_graph, bond_graph, num_nbs, node_mask = graph_inputs
            fatom_list.append(input_atom)
            fbond_list.append(input_bond)
            gatom_list.append(atom_graph)
            gbond_list.append(bond_graph)
            nb_list.append(num_nbs)
            binary_list.append(binary)

            if len(fatom_list) == batch_size:
                graph_inputs = [
                    tf.convert_to_tensor(
                        self._pack2D(fatom_list), dtype=tf.float32),
                    tf.convert_to_tensor(
                        self._pack2D(fbond_list), dtype=tf.float32),
                    tf.convert_to_tensor(
                        self._pack2D_withidx(gatom_list), dtype=tf.int32),
                    tf.convert_to_tensor(
                        self._pack2D_withidx(gbond_list), dtype=tf.int32),
                    tf.convert_to_tensor(
                        self._pack1D(nb_list), dtype=tf.int32),
                    tf.convert_to_tensor(
                        self._get_mask(fatom_list), dtype=tf.float32),
                ]
                binary = tf.convert_to_tensor(binary_list, dtype=tf.float32)
                yield [graph_inputs, binary], target
                fatom_list, fbond_list = [], []
                gatom_list, gbond_list = [], []
                nb_list, binary_list = [], []

    def fit(self, training_data, eval_data, optimizer, batch_size=1, num_epochs=5,
            early_stopping_rounds=10, verbose=1, train_from_scratch=False):
        """ Function to train the model, using the selected optimizer and
            for the desired number of epochs. You can either train from scratch
            or load the latest model trained. Early stopping is used in order to
            mitigate the risk of overfitting the network.

            Args:
                training_data: the data you would like to train the model on.
                                Must be in the tf.data.Dataset format.
                eval_data: the data you would like to evaluate the model on.
                            Must be in the tf.data.Dataset format.
                optimizer: the optimizer used during training.
                num_epochs: the maximum number of iterations you would like to
                            train the model.
                early_stopping_rounds: stop training if the loss on the eval
                                       dataset does not decrease after n epochs.
                verbose: int. Specify how often to print the loss value of the network.
                train_from_scratch: boolean. Whether to initialize variables of the
                                    the last trained model or initialize them
                                    randomly.
        """
        self.batch_size = batch_size

        if not train_from_scratch:
            self.restore_model()

        # Lowest eval loss seen so far (for early stopping).
        best_loss = 999
        # Epochs remaining before early stop. Initialized up front so the
        # first non-improving epoch cannot raise a NameError (original bug:
        # `count` was only assigned inside the improvement branch).
        count = early_stopping_rounds

        # Running means of the train / eval losses.
        train_loss = tfe.metrics.Mean('train_loss')
        eval_loss = tfe.metrics.Mean('eval_loss')

        # Loss history, one entry per epoch.
        self.history = {'train_loss': [], 'eval_loss': []}

        with tf.device(self.device):
            for i in range(num_epochs):
                # Gradient-descent pass over the training data. Gradients
                # are paired with trainable_variables to match grads_fn
                # (the original zipped them with `model.variables`).
                for input_data, target in self._iter_batches(training_data, batch_size):
                    grads = self.grads_fn(input_data, target)
                    optimizer.apply_gradients(
                        zip(grads, self.trainable_variables))

                # Mean training loss after the epoch.
                for input_data, target in self._iter_batches(training_data, batch_size):
                    train_loss(self.loss_fn(input_data, target))
                self.history['train_loss'].append(train_loss.result().numpy())
                train_loss.init_variables()

                # Mean eval loss after the epoch. (Removed a try/except
                # wrapper that dropped into pdb on failure.)
                for input_data, target in self._iter_batches(eval_data, batch_size):
                    eval_loss(self.loss_fn(input_data, target))
                self.history['eval_loss'].append(eval_loss.result().numpy())
                eval_loss.init_variables()

                # Print train and eval losses
                if (i == 0) | ((i + 1) % verbose == 0):
                    print('Train loss at epoch %d: ' %
                          (i + 1), self.history['train_loss'][-1])
                    print('Eval loss at epoch %d: ' %
                          (i + 1), self.history['eval_loss'][-1])

                # Check for early stopping
                if self.history['eval_loss'][-1] < best_loss:
                    best_loss = self.history['eval_loss'][-1]
                    count = early_stopping_rounds
                else:
                    count -= 1
                if count == 0:
                    break


# Build the model and train it from scratch with Adam.
model = chemFormularGCN()
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
model.fit(training_data, eval_data, optimizer, batch_size=1, num_epochs=5,
          early_stopping_rounds=10, verbose=1, train_from_scratch=True)

# # Use Adam with clipped gradients
# optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
# size_func = lambda v: reduce(lambda x, y: x * y, v.get_shape().as_list())
# n = sum(size_func(v) for v in tf.trainable_variables())
# print("Model size: %dK" % (n / 1000,))
#
#
# def count(s):
#     c = 0
#     for i in range(len(s)):
#         if s[i] == ':':
#             c += 1
#     return c
#
#
# saver = tf.train.Saver(max_to_keep=None)
# it, sum_acc, sum_err, sum_gnorm = 0, 0.0, 0.0, 0.0
# lr = 0.001
# try:
#     train_writer = tf.summary.FileWriter(opts.save_path + '/train_1', session.graph)
#     while it < 50000:
#         it += 1
#         # Run one minibatch
#         # input_atom, input_bond, atom_graph, bond_graph, num_nbs,node_mask,label,binary = read_data()
#         # src_tuple = (input_atom, input_bond, atom_graph, bond_graph, num_nbs,node_mask)
#         # feed_map = {x: y for x, y in zip(_src_holder, src_tuple)}
#         # feed_map.update({_label: label, _binary: binary})
#         # session.run(enqueue, feed_dict=feed_map)
#         # feed_dict =dict()
#         # feed_dict[_input_atom] = input_atom
#         # feed_dict[_input_bond] = input_bond
#         # feed_dict[_atom_graph] = atom_graph
#         # feed_dict[_bond_graph] = bond_graph
#         # feed_dict[_num_nbs] = num_nbs
#         # feed_dict[_node_mask] = node_mask
#         # feed_dict[_label] = label
#         # feed_dict[_binary] = binary
#
#         _, cur_topk, pnorm, gnorm, loss_eval, sp_label_eval = session.run(
#             [backprop, topk, param_norm, grad_norm, loss, sp_label], feed_dict={_lr: lr})
#
#         sp_label_eval = [sp_label_eval]
#         # Get performance
#         for i in range(batch_size):
#             pre = 0
#             for j in range(1):
#                 if cur_topk[i, j] in sp_label_eval[i]:
#                     pre += 1
#             if len(sp_label_eval[i]) == pre: sum_err += 1
#             pre = 0
#             for j in range(NK0):
#                 if cur_topk[i, j] in sp_label_eval[i]:
#                     pre += 1
#             if len(sp_label_eval[i]) == pre: sum_acc += 1
#         sum_gnorm += gnorm
#
#         if it % 50 == 0:
#             # train_writer.add_summary(pnorm,it)
#             print(it, " Acc@10: %.4f, Acc@20: %.4f, Param Norm: %.2f, Grad Norm: %.2f, loss %.4f" % (
#                 sum_acc / (50 * batch_size), sum_err / (50 * batch_size), pnorm, sum_gnorm / 50, loss_eval))
#             sys.stdout.flush()
#             sum_acc, sum_err, sum_gnorm = 0.0, 0.0, 0.0
#         # if it % 10000 == 0:
#         if it % 10000 == 0:
#             lr *= 0.9
#             saver.save(session, opts.save_path + "/model50000.ckpt", global_step=it)
#             print("Model Saved!")
# except Exception as e:
#     print(e)
# finally:
#     saver.save(session, opts.save_path + "/model50000.final")
